import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from .nudging import *


class PCLoss(nn.Module):
    """Proxy-contrastive loss built from cosine similarities.

    For every sample the positive logit is its similarity to the proxy of its
    own class. The negatives are its similarities to all other proxies plus
    its similarities to the samples of other classes in the same batch. The
    loss is a scaled softmax cross-entropy with the positive in column 0.
    """

    def __init__(self, num_classes, scale):
        """
        num_classes: number of proxies / classes (targets must lie in
            ``[0, num_classes)``)
        scale: multiplier applied to the similarities before the softmax
        """
        super(PCLoss, self).__init__()
        self.soft_plus = nn.Softplus()  # not used by forward; kept for compatibility
        # Keep the historical placement (GPU) when one exists, but do not
        # crash on CPU-only machines; forward() follows the input device.
        device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        self.label = torch.arange(num_classes, dtype=torch.long, device=device)
        self.scale = scale

    def forward(self, feature, target, proxy):
        '''
        feature: (N, dim)
        target: (N,) integer class ids, on the same device as ``feature``
        proxy: (C, dim), with C == num_classes

        Returns the scalar mean loss over the batch.
        '''
        feature = F.normalize(feature, p=2, dim=1)
        pred = F.linear(feature, F.normalize(proxy, p=2, dim=1))  # (N, C)

        # (N, C) one-hot mask of each sample's own class.
        class_ids = self.label.to(target.device)
        is_positive = target.unsqueeze(1) == class_ids.unsqueeze(0)

        pred_p = pred[is_positive].unsqueeze(1)                   # (N, 1)
        # Relies on exactly one positive per row, i.e. valid targets.
        pred_n = pred[~is_positive].view(feature.size(0), -1)     # (N, C-1)

        # Sample-to-sample similarities; same-class pairs are zeroed first.
        sample_sim = torch.matmul(feature, feature.transpose(1, 0))
        same_class = target.unsqueeze(1) == target.unsqueeze(0)
        sample_sim = sample_sim * ~same_class
        # -inf removes an entry from the softmax. Besides the zeroed
        # same-class pairs this also drops every cross-class pair whose
        # similarity is below 1e-6 (including all negative similarities).
        # NOTE(review): confirm that dropping those negatives is intended.
        sample_sim = sample_sim.masked_fill(sample_sim < 1e-6, -np.inf)

        logits = torch.cat([pred_p, pred_n, sample_sim], dim=1)
        # The positive logit always sits in column 0.
        positive_index = torch.zeros(
            logits.size(0), dtype=torch.long, device=logits.device)
        return F.cross_entropy(self.scale * logits, positive_index)


class PCALoss(nn.Module):
    """Proxy-based contrastive loss with an extra, ``pi``-weighted base-proxy term.

    The active formula is::

        -log( exp(s * pred_p) /
              (exp(s * pred_p) + sum(exp(s * pred_n)) +
               mweight * sum(exp(s * Bpred * pi_all))) ).mean()

    where ``pred`` / ``Bpred`` are cosine similarities of the samples to
    ``proxy`` / ``BaseProxy`` and ``pred_p`` / ``pred_n`` are ``pred`` masked
    to the positive / non-positive entries.
    """

    def __init__(self, num_classes, scale):
        """
        num_classes: number of class ids compared against ``target``
        scale: multiplier applied to the similarities before ``exp``
        """
        super(PCALoss, self).__init__()
        self.soft_plus = nn.Softplus()  # not used by forward; kept for compatibility
        # Keep the historical placement (GPU) when one exists, but do not
        # crash on CPU-only machines; forward() follows the input device.
        device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        self.label = torch.arange(num_classes, dtype=torch.long, device=device)
        self.scale = scale

    def forward(self, feature, new_feature, target, proxy, Mproxy, BaseProxy, pi,mweight=1):
        '''
        feature: (N, dim)
        new_feature: not used by the active loss; kept so callers need not change
        target: integer class ids, on the same device as ``feature``
        proxy: (C, dim)
        Mproxy: (C, dim); only its row count is used (it sizes the all-False
            part of the positive mask)
        BaseProxy: (B, dim) proxies whose similarities are weighted by ``pi``
        pi: weights; ``pi.T`` is stacked below rows of ones so that the result
            has N rows and multiplies ``Bpred`` element-wise
        mweight: weight of the base-proxy term in the denominator

        Shape requirements that follow from the mask construction below:
        N must equal num_classes, and ``proxy`` must have
        ``Mproxy.size(0) + N`` rows (or broadcast against that).
        '''
        feature_n = F.normalize(feature, p=2, dim=1)
        pred = F.linear(feature_n, F.normalize(proxy, p=2, dim=1))       # (N, C)
        Bpred = F.linear(feature_n, F.normalize(BaseProxy, p=2, dim=1))  # (N, B)

        # Positive mask: an all-False block as wide as Mproxy has rows,
        # followed by the class-id/target comparison.
        # NOTE(review): label_a is (num_classes, N) and is concatenated
        # untransposed, so it lines up with the rows of ``pred`` only when
        # N == num_classes -- confirm this matches the calling code.
        class_ids = self.label.to(target.device)
        label_a = class_ids.unsqueeze(1) == target.unsqueeze(0)
        label_b = torch.zeros(feature.size(0), Mproxy.size(0),
                              dtype=torch.bool, device=label_a.device)
        label = torch.cat([label_b, label_a], dim=1)

        # Rows without an entry in ``pi`` get a neutral weight of 1.
        pi_new = pi.transpose(1, 0)
        pi_one = torch.ones(Bpred.size(0) - pi_new.size(0), Bpred.size(1),
                            device=Bpred.device)
        pi_all = torch.cat([pi_one, pi_new], dim=0)

        # Masking by multiplication leaves zeros (not -inf) in the unselected
        # positions, so those positions contribute exp(0) == 1 below.
        # NOTE(review): pred_p therefore stays a full matrix rather than one
        # positive score per sample -- confirm this is the intended loss.
        pred_p = pred * label
        pred_n = pred * ~label

        # The similarities to Mproxy, the new_feature/BaseProxy similarities
        # and the sample-to-sample similarity matrix were computed here
        # before but never entered the loss, so they are no longer evaluated.
        pos_exp = torch.exp(self.scale * pred_p.squeeze())
        neg_sum = torch.exp(self.scale * pred_n).sum(dim=1)
        base_sum = mweight * torch.exp(self.scale * Bpred * pi_all).sum(dim=1)

        loss = -torch.log(pos_exp / (pos_exp + neg_sum + base_sum)).mean()
        return loss

# class PCALoss(nn.Module):
#     def __init__(self, num_classes, scale):
#         super(PCALoss, self).__init__()
#         self.soft_plus = nn.Softplus()
#         self.label = torch.LongTensor([i for i in range(num_classes)]).cuda()
#         self.scale = scale
#
#     def forward(self, feature, new_feature, target, proxy, Mproxy, BaseProxy, pi,mweight=1):
#         '''
#         feature: (N, dim)
#         proxy: (C, dim)
#         Mproxy: (C, dim)
#         '''
#
#         pred = F.linear(F.normalize(feature, p=2, dim=1), F.normalize(proxy, p=2, dim=1))  # (N, C)  similarity between sample and proxy
#         Mpred = F.linear(F.normalize(feature, p=2, dim=1), F.normalize(Mproxy, p=2, dim=1))  # (N, C)  similarity between sample and old proxy
#         Bpred = F.linear(F.normalize(feature, p=2, dim=1), F.normalize(BaseProxy, p=2, dim=1))
#
#         NB_pred = F.linear(F.normalize(new_feature, p=2, dim=1), F.normalize(BaseProxy, p=2, dim=1))
#
#
#         # NB_pred = NB_pred.transpose(1,0) * pi
#         pi_new = pi.transpose(1,0)
#
#
#         label_a = (self.label.unsqueeze(1) == target.unsqueeze(0))
#
#
#         label_b = torch.zeros(Mpred.size()[0],Mpred.size()[1]).bool().cuda()
#
#         label = torch.cat([label_b,label_a],dim=1)
#
#         pi_one = torch.ones(Mpred.size(0)-pi_new.size(0) ,Mpred.size(1)).cuda()
#
#
#         pi_all = torch.cat([pi_one, pi_new],dim=0)
#
#         # pred_p = torch.masked_select(pred, label.transpose(1, 0))  # (N)   positive pair
#         # pred_p = pred_p.unsqueeze(1)
#         pred_p = pred * label
#         pred_n = pred * ~label
#
#         # print(torch.masked_select(pred, ~label.transpose(1, 0)).shape)
#         # pred_n = torch.masked_select(pred, ~label.transpose(1, 0)).view(feature.size(0), -1)  # (N, C-1) negative pair of anchor and proxy
#         # pred_n = torch.masked_select(pred, ~label.transpose(1, 0))
#
#         # Mpred_p = torch.masked_select(Mpred, label.transpose(1, 0))
#         # Mpred_p = Mpred_p.unsqueeze(1)
#         # Mpred_n = torch.masked_select(Mpred, ~label.transpose(1, 0)).view(feature.size(0), -1)
#         Mpred_n = Mpred
#
#         feature = torch.matmul(feature, feature.transpose(1, 0))  # (N, N)  sample wise similarity
#
#         label_matrix = label
#         # print(label_matrix)
#
#         feature = feature * ~label_matrix
#         feature = feature.masked_fill(feature < 1e-6, -np.inf)
#         # print(feature.shape)
#
#         # loss = -torch.log(
#         #     (torch.exp(self.scale * pred_p.squeeze()) + mweight * torch.exp(self.scale * Mpred_p.squeeze())) /
#         #     (torch.exp(self.scale * pred_p.squeeze()) + mweight * torch.exp(self.scale * Mpred_p.squeeze()) +
#         #      torch.exp(self.scale * pred_n).sum(dim=1) + mweight * torch.exp(self.scale * Mpred_n).sum(
#         #                 dim=1) + torch.exp(self.scale * feature).sum(dim=1))).mean()
#
#
#         # loss = -torch.log(
#         #     (torch.exp(self.scale * pred_p.squeeze())) /
#         #     (torch.exp(self.scale * pred_p.squeeze())  + mweight *torch.exp(self.scale * pred_n).sum(dim=1) + torch.exp(self.scale * feature).sum(dim=1)+ mweight * torch.exp(self.scale * Mpred_n).sum(dim=1) +mweight * torch.exp(self.scale * Bpred).sum(dim=1)
#         # )
#         # ).mean()
#
#
#
#         loss = -torch.log(
#             (torch.exp(self.scale * pred_p.squeeze())) /
#             (torch.exp(self.scale * pred_p.squeeze()) +  torch.exp(self.scale * pred_n).sum(dim=1) + mweight * torch.exp(self.scale * Mpred * pi_all).sum(dim=1)
#
#              )
#         ).mean()
#
#
#         # loss = -torch.log(
#         #     (torch.exp(self.scale * pred_p.squeeze())) /
#         #     (torch.exp(self.scale * pred_p.squeeze())  +
#         #      mweight * torch.exp(self.scale * Mpred_n).sum(dim=1) +
#         #      torch.exp(self.scale * feature).sum(dim=1))
#         # ).mean()
#
#
#         return loss


class PCALoss_v2(nn.Module):
    """Proxy-contrastive loss with an additional proxy-to-proxy term.

    The denominator contains the sample-to-proxy negatives, the
    sample-to-sample negatives of other classes, and the pairwise proxy
    similarities selected by a strictly upper-triangular mask.
    """

    def __init__(self, num_classes, scale):
        """
        num_classes: number of proxies / classes
        scale: multiplier applied to the similarities before ``exp``
        """
        super(PCALoss_v2, self).__init__()
        self.soft_plus = nn.Softplus()  # not used by forward; kept for compatibility
        # Keep the historical placement (GPU) when one exists, but do not
        # crash on CPU-only machines; forward() follows the input device.
        device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        self.label = torch.arange(num_classes, dtype=torch.long, device=device)
        self.scale = scale
        # Strictly upper-triangular 0/1 mask: each unordered proxy pair once.
        # Stored as a plain tensor (as before), so it is neither a registered
        # parameter nor part of the state_dict.
        self.mask = torch.triu(
            torch.ones([num_classes, num_classes], dtype=torch.uint8),
            diagonal=1).to(device)

    def forward(self, feature, target, proxy):
        '''
        feature: (N, dim)
        target: (N,) integer class ids, on the same device as ``feature``
        proxy: (C, dim), with C == num_classes

        Returns the mean loss. The per-proxy term has one entry per proxy and
        is added element-wise to the per-sample terms, so N must equal C
        (or broadcast against it).
        NOTE(review): confirm that pairing sample i with proxy row i is the
        intended behaviour.
        '''
        feature = F.normalize(feature, p=2, dim=1)
        proxy_n = F.normalize(proxy, p=2, dim=1)
        pred = F.linear(feature, proxy_n)  # (N, C) similarity between sample and proxy

        # (N, C) one-hot mask of each sample's own class.
        class_ids = self.label.to(target.device)
        is_positive = target.unsqueeze(1) == class_ids.unsqueeze(0)
        pred_p = pred[is_positive].unsqueeze(1)                   # (N, 1)
        # Relies on exactly one positive per row, i.e. valid targets.
        pred_n = pred[~is_positive].view(feature.size(0), -1)     # (N, C-1)

        # Sample-to-sample similarities; same-class pairs are zeroed, then
        # everything below 1e-6 is set to -inf so that exp() maps it to 0.
        sample_sim = torch.matmul(feature, feature.transpose(1, 0))  # (N, N)
        same_class = target.unsqueeze(1) == target.unsqueeze(0)
        sample_sim = sample_sim * ~same_class
        sample_sim = sample_sim.masked_fill(sample_sim < 1e-6, -np.inf)

        # Proxy-to-proxy similarities, restricted to the upper triangle.
        proxy_c = torch.matmul(proxy_n, proxy_n.transpose(1, 0))
        proxy_c = proxy_c * self.mask.to(proxy_c.device)
        proxy_c = proxy_c.masked_fill(proxy_c < 1e-6, -np.inf)

        pos_exp = torch.exp(self.scale * pred_p.squeeze())
        denominator = (
            pos_exp
            + torch.exp(self.scale * pred_n).sum(dim=1)
            + torch.exp(self.scale * sample_sim).sum(dim=1)
            + torch.exp(self.scale * proxy_c).sum(dim=1)
        )
        loss = -torch.log(pos_exp / denominator).mean()
        return loss




class CorrelationLoss(torch.nn.Module):
    """Penalty on the pairwise cosine similarity between prototypes.

    The prototypes are squashed with ``tanh``, L2-normalised, and their
    pairwise similarities are masked to the strict upper triangle. The chosen
    activation is then applied and the summed result is divided by the number
    of masked-in pairs.
    """

    def __init__(self, num_ways, act="doubleexp", act_exp=4):
        """
        num_ways: number of prototypes (rows expected in ``forward``)
        act: "exp" or "doubleexp"; selects ``exp_loss`` / ``doubleexp_loss``
        act_exp: argument passed to the selected activation factory

        Raises ValueError for any other ``act``.
        """
        super(CorrelationLoss, self).__init__()

        self.act_exp = act_exp
        # Keep the historical placement (GPU) when one exists, but do not
        # crash on CPU-only machines; forward() follows the input device.
        device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        # Plain tensors on purpose: the previous ``Parameter(...).cuda()``
        # also produced unregistered tensors, so neither value appears in
        # ``parameters()`` or the state_dict.
        self.tnhscaleP = torch.ones([1], dtype=torch.float32).to(device)
        self.mask = torch.triu(
            torch.ones([num_ways, num_ways], dtype=torch.uint8),
            diagonal=1).to(device)
        self.mask_sum = torch.sum(self.mask)  # number of prototype pairs
        self.cos = torch.nn.CosineSimilarity()  # not used by forward; kept for compatibility

        if act == "exp":
            self.act = exp_loss(act_exp)
        elif act == "doubleexp":
            self.act = doubleexp_loss(act_exp)
        else:
            raise ValueError("Non-valid nudging activation function. Got {:}".format(act))

    def forward(self, initial_prototypes):
        """Return the cross-correlation loss of ``initial_prototypes``.

        initial_prototypes: 2-D tensor with one prototype per row; the row
            count must match ``num_ways`` for the mask to apply.
        """
        # Use the names this file imports explicitly (``torch`` / ``F``)
        # instead of ``t`` / ``f``, which are not defined in this module.
        device = initial_prototypes.device
        prod_vecs = torch.tanh(self.tnhscaleP.to(device) * initial_prototypes)
        norm_prod_vecs = F.normalize(prod_vecs, p=2, dim=1)
        prod_sims = torch.matmul(
            norm_prod_vecs, norm_prod_vecs.transpose(0, 1)) * self.mask.to(device)
        # The activation sees the full matrix, including the zeroed
        # lower-triangle/diagonal entries, which therefore add act(0) each.
        prod_sim_loss = self.act(prod_sims)
        # NOTE(review): mask_sum is 0 when num_ways == 1, which makes this
        # division undefined -- confirm callers never use a single way.
        prod_sim_loss = torch.sum(prod_sim_loss) / self.mask_sum.to(device)
        return prod_sim_loss

